{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Creating a Lower Limb Panoramic X-ray <a href=\"https://mybinder.org/v2/gh/InsightSoftwareConsortium/SimpleITK-Notebooks/master?filepath=Python%2F69_x-ray-panorama.ipynb\"><img style=\"float: right;\" src=\"https://mybinder.org/badge_logo.svg\"></a>\n",
    "\n",
    "\n",
    "Measurement of knee alignment is useful for diagnosing arthritic conditions and for planning and evaluating surgical interventions. Alignment is measured by the hip-knee-ankle ($HKA$) angle in standing, load-bearing x-ray images. The angle is defined by the femoral and tibial mechanical axes. The femoral axis runs from the center of the femoral head to the mid-condylar point. The tibial axis runs from the center of the tibial plateau to the center of the tibial plafond.\n",
    "\n",
    "\n",
    "\n",
    "<figure>\n",
    "  <img src=\"hkaAngle.png\" style=\"width:80px\"/>\n",
    "  <figcaption style=\"text-align:center\"> Hip-Knee-Ankle angle defined by the femoral mechanical axis (solid red line with dashed extension), and tibial mechanical axis (solid blue line).</figcaption>\n",
    "</figure> \n",
    "\n",
    "\n",
    "The three stances defined by the $HKA$ angle are:\n",
    " 1. Neutral alignment, $HKA=0^\\circ$.\n",
    " 2. Varus, bow-legged, $HKA<0^\\circ$.\n",
    " 3. Valgus, knock-kneed, $HKA>0^\\circ$.\n",
    "\n",
    "For additional information see:\n",
    "1. T. D. Cooke et al., \"[Frontal plane knee alignment: a call for standardized measurement](https://www.ncbi.nlm.nih.gov/pubmed/17787049)\", J Rheumatol. 2007.\n",
    "2. A. F. Kamath et al., \"[What is Varus or Valgus Knee Alignment?: A Call for a Uniform Radiographic Classification](https://www.ncbi.nlm.nih.gov/pubmed/20361279)\", Clin Orthop Relat Res. 2010.\n",
    "\n",
    "For a robust estimate of the $HKA$ angle we would like to use a single image that contains the anatomy from the femoral head down to the ankle. Acquiring such an image with standard x-ray imaging devices is not possible. It is achievable by acquiring multiple partially overlapping images and aligning (registering) them to the same coordinate system, which is the subject of this notebook.\n",
    "\n",
    "This notebook is based in part on the work described in: \"A marker-free registration method for standing X-ray panorama reconstruction for hip-knee-ankle axis deformity assessment\", Y. K. Ben-Zikri, Z. Yaniv, K. Baum, C. A. Linte, *Computer Methods in Biomechanics and Biomedical Engineering: Imaging & Visualization*, DOI:10.1080/21681163.2018.1537859.\n",
    "\n"
   ]
  },
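  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal illustrative sketch (not part of the original workflow), a signed $HKA$ angle can be mapped to these three stance labels. The `tolerance` around zero is a hypothetical choice; clinical thresholds vary across the literature cited above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def classify_hka(hka_degrees, tolerance=1.0):\n",
    "    # Map a signed HKA angle (degrees) to one of the three stances above.\n",
    "    # The tolerance around zero is a hypothetical, illustrative choice.\n",
    "    if abs(hka_degrees) <= tolerance:\n",
    "        return \"neutral\"\n",
    "    return \"varus\" if hka_degrees < 0 else \"valgus\"\n",
    "\n",
    "\n",
    "print(classify_hka(-4.2), classify_hka(0.3), classify_hka(5.1))"
   ]
  },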
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import SimpleITK as sitk\n",
    "import numpy as np\n",
    "import os.path\n",
    "import copy\n",
    "\n",
    "%matplotlib widget\n",
    "import gui\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# utility method that either downloads data from the Girder repository or\n",
    "# if already downloaded returns the file name for reading from disk (cached data)\n",
    "%run update_path_to_download_script\n",
    "from downloaddata import fetch_data as fdata"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch all of the data associated with this example.\n",
    "data_directory = os.path.dirname(fdata(\"leg_panorama/readme.txt\"))\n",
    "\n",
    "hip_image = sitk.ReadImage(os.path.join(data_directory, \"hip.mha\"))\n",
    "knee_image = sitk.ReadImage(os.path.join(data_directory, \"knee.mha\"))\n",
    "ankle_image = sitk.ReadImage(os.path.join(data_directory, \"ankle.mha\"))\n",
    "\n",
    "gui.multi_image_display2D([hip_image, knee_image, ankle_image], figure_size=(10, 4));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Getting to know your data\n",
    "\n",
    "As our goal is to register the images we need to identify an appropriate **similarity metric** and **transformation type**. \n",
    "\n",
    "### Similarity metric\n",
    "\n",
    "Given that we are using the same device to acquire multiple partially overlapping images, we would expect that the intensities for the same anatomical structures are the same in all images. We start by visually inspecting the images displayed above. If you hover the cursor over the images you will see the intensity value on the bottom right.\n",
    "\n",
    "We next plot the histogram for one of the images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "intensity_profile_image = knee_image\n",
    "fig = plt.figure()\n",
    "plt.hist(sitk.GetArrayViewFromImage(intensity_profile_image).flatten(), bins=100);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that the image has a high dynamic range which is mapped to a low dynamic range when displayed, so we cannot observe all of the underlying intensity variations. Ideally, intensity variations in x-ray images only occur when there are variations in the imaged object. In practice, we can observe non-uniform intensities due to the structure of the x-ray device (e.g. absorption of photons by the x-ray anode, known as the [heel effect](https://en.wikipedia.org/wiki/Heel_effect)).\n",
    "\n",
    "In the next code cells we define a rectangular region of interest (click and drag with the left mouse button to define one) in an area that is expected to have uniform intensity values (air) and plot the mean intensity per row. You can readily notice that there are intensity variations in what is expected to be a uniform region."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The ROI we specify is in a region that is expected to have uniform intensities.\n",
    "# You can clear this ROI and specify your own in the GUI below.\n",
    "roi_list = [((396, 436), (52, 1057))]\n",
    "roi_gui = gui.ROIDataAquisition(intensity_profile_image, figure_size=(8, 4))\n",
    "roi_gui.set_rois(roi_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the region of interest (first entry in the list of ROIs)\n",
    "roi = roi_gui.get_rois()[0]\n",
    "intensities = sitk.GetArrayFromImage(\n",
    "    intensity_profile_image[roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]\n",
    ")\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, sharey=True)\n",
    "fig.suptitle(\"intensity variations (mean row value)\")\n",
    "axes[0].imshow(intensities, cmap=plt.cm.Greys_r)\n",
    "axes[0].set_axis_off()\n",
    "axes[1].plot(intensities.mean(axis=1), range(intensities.shape[0]))\n",
    "axes[1].get_yaxis().set_visible(False)\n",
    "axes[1].tick_params(axis=\"x\", rotation=-90)\n",
    "plt.box(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given our observations above, we will use **correlation** as our similarity measure and not mean squares; correlation is robust to linear intensity differences between the images, while mean squares is not."
   ]
  },
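  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustration of this choice (a synthetic sketch, not data from this notebook): correlation is invariant to a global linear intensity change between the images, while mean squares is not."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Synthetic intensities: the \"moving\" values have the same structure as the\n",
    "# \"fixed\" ones, up to a global linear intensity change.\n",
    "rng = np.random.default_rng(0)\n",
    "fixed_intensities = rng.random(1000)\n",
    "moving_intensities = 1.5 * fixed_intensities + 10.0\n",
    "\n",
    "mean_squares = np.mean((fixed_intensities - moving_intensities) ** 2)\n",
    "correlation = np.corrcoef(fixed_intensities, moving_intensities)[0, 1]\n",
    "print(f\"mean squares: {mean_squares:.2f}, correlation: {correlation:.4f}\")"
   ]
  },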
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transformation type\n",
    "\n",
    "In general, the x-ray machine is modeled as a pinhole camera, with our images acquired using a front-parallel setup and the camera undergoing translation. This simplifies the general model from a homography transformation between images to a **planar translation**. For a detailed derivation see Z. Yaniv, L. Joskowicz, \"[Long Bone Panoramas from Fluoroscopic X-ray Images](https://www.ncbi.nlm.nih.gov/pubmed/14719684)\", IEEE Trans Med Imaging. 2004. \n",
    "\n",
    "While our transformation type is translation, looking at multiple triplets of images we observed that the size of the overlapping regions, and thus the expected translations, varies significantly. Consequently, we will use a heuristic **exploration-exploitation** approach to improve the robustness of our registration.\n",
    "\n",
    "\n",
    "\n",
    "## Registration - Exploration Step\n",
    "\n",
    "As the image overlap varies considerably, we will use the ExhaustiveOptimizer to obtain several starting points, our exploration step. We then run a standard registration from each of these initial transformation estimates, our exploitation step. Finally, we select the best transformation from the exploitation step."
   ]
  },
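  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The exploration-exploitation heuristic can be sketched on a toy 1D problem (hypothetical objective, standing in for the negated correlation): sample a coarse grid, keep several of the best samples as candidates, refine each candidate locally, and select the best refined solution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def toy_objective(x):\n",
    "    # Hypothetical multi-modal objective standing in for the negated correlation.\n",
    "    return np.sin(3 * x) + 0.1 * (x - 2.0) ** 2\n",
    "\n",
    "\n",
    "# Exploration: coarse exhaustive sampling, keep the four best samples as candidates.\n",
    "xs = np.linspace(-4.0, 4.0, 41)\n",
    "candidates = xs[np.argsort(toy_objective(xs))[:4]]\n",
    "\n",
    "# Exploitation: refine each candidate locally (here with a fine grid search).\n",
    "refined = [\n",
    "    min(np.linspace(c - 0.2, c + 0.2, 81), key=toy_objective) for c in candidates\n",
    "]\n",
    "\n",
    "# Select the best refined solution.\n",
    "best = min(refined, key=toy_objective)"
   ]
  },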
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Evaluate2DTranslationCorrelation:\n",
    "    \"\"\"\n",
    "    Class for evaluating the correlation value for a given set of\n",
    "    2D translations between two images. A fixed spatial relationship between\n",
    "    the images is assumed (the fixed image is \"below\" the moving one).\n",
    "    We use the Exhaustive optimizer to sample the possible set of translations\n",
    "    and an observer to tabulate the results.\n",
    "\n",
    "    In this class we abuse the Python dictionary by using float\n",
    "    values as keys. This is safe here because the values come from a\n",
    "    fixed sampling grid (they are not the results of computations), so\n",
    "    equal values compare exactly equal and hash to the same key.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        metric_sampling_percentage,\n",
    "        min_row_overlap,\n",
    "        max_row_overlap,\n",
    "        column_overlap,\n",
    "        dx_step_num,\n",
    "        dy_step_num,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            metric_sampling_percentage: Percentage of samples to use\n",
    "                                        when computing correlation.\n",
    "            min_row_overlap: Minimal number of rows that overlap between\n",
    "                             the two images.\n",
    "            max_row_overlap: Maximal number of rows that overlap between\n",
    "                             the two images.\n",
    "            column_overlap: Maximal translation along the columns, in both\n",
    "                            the positive and negative directions.\n",
    "            dx_step_num: Number of samples in parameter space for translation along\n",
    "                         the x axis is 2*dx_step_num+1.\n",
    "            dy_step_num: Number of samples in parameter space for translation along\n",
    "                         the y axis is 2*dy_step_num+1.\n",
    "\n",
    "        \"\"\"\n",
    "        self._registration_values_dict = {}\n",
    "        self.X = None\n",
    "        self.Y = None\n",
    "        self.C = None\n",
    "        self._metric_sampling_percentage = metric_sampling_percentage\n",
    "        self._min_row_overlap = min_row_overlap\n",
    "        self._max_row_overlap = max_row_overlap\n",
    "        self._column_overlap = column_overlap\n",
    "        self._dx_step_num = dx_step_num\n",
    "        self._dy_step_num = dy_step_num\n",
    "\n",
    "    def _start_observer(self):\n",
    "        self._registration_values_dict = {}\n",
    "        self.X = None\n",
    "        self.Y = None\n",
    "        self.C = None\n",
    "\n",
    "    def _iteration_observer(self, registration_method):\n",
    "        x, y = registration_method.GetOptimizerPosition()\n",
    "        if y in self._registration_values_dict.keys():\n",
    "            self._registration_values_dict[y].append(\n",
    "                (x, registration_method.GetMetricValue())\n",
    "            )\n",
    "        else:\n",
    "            self._registration_values_dict[y] = [\n",
    "                (x, registration_method.GetMetricValue())\n",
    "            ]\n",
    "\n",
    "    def evaluate(self, fixed_image, moving_image):\n",
    "        \"\"\"\n",
    "        Assume the fixed image is lower than the moving image (e.g. fixed=knee, moving=hip).\n",
    "        The transformations map points in the fixed_image to the moving_image.\n",
    "        Args:\n",
    "            fixed_image: Image to use as fixed image in the registration.\n",
    "            moving_image: Image to use as moving image in the registration.\n",
    "        \"\"\"\n",
    "        minimal_overlap = np.array(\n",
    "            moving_image.TransformContinuousIndexToPhysicalPoint(\n",
    "                (\n",
    "                    -self._column_overlap,\n",
    "                    moving_image.GetHeight() - self._min_row_overlap,\n",
    "                )\n",
    "            )\n",
    "        ) - np.array(fixed_image.GetOrigin())\n",
    "        maximal_overlap = np.array(\n",
    "            moving_image.TransformContinuousIndexToPhysicalPoint(\n",
    "                (self._column_overlap, moving_image.GetHeight() - self._max_row_overlap)\n",
    "            )\n",
    "        ) - np.array(fixed_image.GetOrigin())\n",
    "        transform = sitk.TranslationTransform(\n",
    "            2,\n",
    "            (\n",
    "                (maximal_overlap[0] + minimal_overlap[0]) / 2.0,\n",
    "                (maximal_overlap[1] + minimal_overlap[1]) / 2.0,\n",
    "            ),\n",
    "        )\n",
    "\n",
    "        # The exhaustive optimizer samples translations in both directions around the\n",
    "        # initial value: 2*dy_step_num+1 samples along y and 2*dx_step_num+1 along x.\n",
    "        dy_step_length = (maximal_overlap[1] - minimal_overlap[1]) / (\n",
    "            2 * self._dy_step_num\n",
    "        )\n",
    "        dx_step_length = (maximal_overlap[0] - minimal_overlap[0]) / (\n",
    "            2 * self._dx_step_num\n",
    "        )\n",
    "        step_length = dx_step_length\n",
    "        parameter_scales = [1, dy_step_length / dx_step_length]\n",
    "\n",
    "        registration_method = sitk.ImageRegistrationMethod()\n",
    "        registration_method.SetMetricAsCorrelation()\n",
    "        registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)\n",
    "        registration_method.SetMetricSamplingPercentage(\n",
    "            self._metric_sampling_percentage\n",
    "        )\n",
    "        registration_method.SetInitialTransform(transform, inPlace=True)\n",
    "        registration_method.SetOptimizerAsExhaustive(\n",
    "            numberOfSteps=[self._dx_step_num, self._dy_step_num], stepLength=step_length\n",
    "        )\n",
    "        registration_method.SetOptimizerScales(parameter_scales)\n",
    "\n",
    "        registration_method.AddCommand(\n",
    "            sitk.sitkIterationEvent,\n",
    "            lambda: self._iteration_observer(registration_method),\n",
    "        )\n",
    "        registration_method.AddCommand(sitk.sitkStartEvent, self._start_observer)\n",
    "        registration_method.Execute(fixed_image, moving_image)\n",
    "\n",
    "        # Convert the data obtained by the observer to three numpy arrays X,Y,C\n",
    "        x_lists = []\n",
    "        val_lists = []\n",
    "        for k in self._registration_values_dict.keys():\n",
    "            x_list, val_list = zip(*(sorted(self._registration_values_dict[k])))\n",
    "            x_lists.append(x_list)\n",
    "            val_lists.append(val_list)\n",
    "\n",
    "        self.X = np.array(x_lists)\n",
    "        self.C = np.array(val_lists)\n",
    "        self.Y = np.array(\n",
    "            [\n",
    "                list(self._registration_values_dict.keys()),\n",
    "            ]\n",
    "            * self.X.shape[1]\n",
    "        ).transpose()\n",
    "\n",
    "    def get_raw_data(self):\n",
    "        \"\"\"\n",
    "        Get the raw data, the translations and corresponding correlation values.\n",
    "        Returns:\n",
    "            A tuple of three numpy arrays (X,Y,C) where (X[i], Y[i]) are the translation\n",
    "            and C[i] is the correlation value for that translation.\n",
    "        \"\"\"\n",
    "        return (np.copy(self.X), np.copy(self.Y), np.copy(self.C))\n",
    "\n",
    "    def get_candidates(self, num_candidates, correlation_threshold, nms_radius=2):\n",
    "        \"\"\"\n",
    "        Get the best transformations from the sample set (most correlated; as the\n",
    "        optimizer negates the correlation these have the minimal metric value).\n",
    "        Args:\n",
    "            num_candidates: Maximal number of candidates to return.\n",
    "            correlation_threshold: Minimal correlation value required for returning\n",
    "                                   a candidate.\n",
    "            nms_radius: Non-Minima (the optimizer is negating the correlation) suppression\n",
    "                        region around the local minimum.\n",
    "        Returns:\n",
    "            List of tuples containing (transform, correlation). The order of the transformations\n",
    "            in the list is based on the correlation value (best correlation is entry zero).\n",
    "        \"\"\"\n",
    "        candidates = []\n",
    "        _C = np.copy(self.C)\n",
    "        done = num_candidates - len(candidates) <= 0\n",
    "        while not done:\n",
    "            min_index = np.unravel_index(_C.argmin(), _C.shape)\n",
    "            if -_C[min_index] < correlation_threshold:\n",
    "                done = True\n",
    "            else:\n",
    "                candidates.append(\n",
    "                    (\n",
    "                        sitk.TranslationTransform(\n",
    "                            2, (self.X[min_index], self.Y[min_index])\n",
    "                        ),\n",
    "                        self.C[min_index],\n",
    "                    )\n",
    "                )\n",
    "                # Non-minima suppression in the region around our minimum\n",
    "                start_nms = np.maximum(\n",
    "                    np.array(min_index) - nms_radius, np.array([0, 0])\n",
    "                )\n",
    "                # for the end coordinate we add nms_radius+1 because the slicing operator _C[],\n",
    "                # excludes the end\n",
    "                end_nms = np.minimum(\n",
    "                    np.array(min_index) + nms_radius + 1, np.array(_C.shape)\n",
    "                )\n",
    "                _C[start_nms[0] : end_nms[0], start_nms[1] : end_nms[1]] = 0\n",
    "                done = num_candidates - len(candidates) <= 0\n",
    "        return candidates\n",
    "\n",
    "\n",
    "def create_images_in_shared_coordinate_system(image_transform_list):\n",
    "    \"\"\"\n",
    "    Resample a set of images onto the same region in space (the bounding\n",
    "    box of all the images).\n",
    "    Args:\n",
    "        image_transform_list: A list of tuples, each containing an image and a transformation. The transformations map the\n",
    "                              images to a shared coordinate system.\n",
    "    Returns:\n",
    "        list of images: All images are resampled into the same coordinate system and the bounding box of all images is\n",
    "                        used to define the new image extent onto which the originals are resampled. The background value\n",
    "                        for the resampled images is set to 0.\n",
    "    \"\"\"\n",
    "    pnt_list = []\n",
    "    for image, transform in image_transform_list:\n",
    "        pnt_list.append(transform.TransformPoint(image.GetOrigin()))\n",
    "        pnt_list.append(\n",
    "            transform.TransformPoint(\n",
    "                image.TransformIndexToPhysicalPoint(\n",
    "                    (image.GetWidth() - 1, image.GetHeight() - 1)\n",
    "                )\n",
    "            )\n",
    "        )\n",
    "\n",
    "    max_coordinates = np.max(pnt_list, axis=0)\n",
    "    min_coordinates = np.min(pnt_list, axis=0)\n",
    "\n",
    "    # We assume the spacing for all original images is the same and we keep it.\n",
    "    output_spacing = image_transform_list[0][0].GetSpacing()\n",
    "    # We assume the pixel type for all images is the same and we keep it.\n",
    "    output_pixelID = image_transform_list[0][0].GetPixelID()\n",
    "    # We assume the direction for all images is the same and we keep it.\n",
    "    output_direction = image_transform_list[0][0].GetDirection()\n",
    "    output_width = int(\n",
    "        np.round((max_coordinates[0] - min_coordinates[0]) / output_spacing[0])\n",
    "    )\n",
    "    output_height = int(\n",
    "        np.round((max_coordinates[1] - min_coordinates[1]) / output_spacing[1])\n",
    "    )\n",
    "    output_origin = (min_coordinates[0], min_coordinates[1])\n",
    "\n",
    "    images_in_shared_coordinate_system = []\n",
    "    for image, transform in image_transform_list:\n",
    "        images_in_shared_coordinate_system.append(\n",
    "            sitk.Resample(\n",
    "                image,\n",
    "                (output_width, output_height),\n",
    "                transform.GetInverse(),\n",
    "                sitk.sitkLinear,\n",
    "                output_origin,\n",
    "                output_spacing,\n",
    "                output_direction,\n",
    "                0.0,\n",
    "                output_pixelID,\n",
    "            )\n",
    "        )\n",
    "    return images_in_shared_coordinate_system\n",
    "\n",
    "\n",
    "def composite_images_alpha_blending(images_in_shared_coordinate_system, alpha=0.5):\n",
    "    \"\"\"\n",
    "    Composite a list of images sharing the same extent (size, origin, spacing, direction cosine).\n",
    "    Args:\n",
    "        images_in_shared_coordinate_system: A list of images sharing the same meta-information (origin, size, spacing, direction cosine).\n",
    "        We assume zero denotes background.\n",
    "    Returns:\n",
    "        SimpleITK image with pixel type sitkFloat32: alpha blending of the images.\n",
    "\n",
    "    \"\"\"\n",
    "    # Composite all of the images using alpha blending where there is overlap between two images, otherwise\n",
    "    # just paste the image values into the composite image. We assume that at most two images overlap.\n",
    "    composite_image = sitk.Cast(images_in_shared_coordinate_system[0], sitk.sitkFloat32)\n",
    "    for img in images_in_shared_coordinate_system[1:]:\n",
    "        current_image = sitk.Cast(img, sitk.sitkFloat32)\n",
    "        mask1 = sitk.Cast(composite_image != 0, sitk.sitkFloat32)\n",
    "        mask2 = sitk.Cast(current_image != 0, sitk.sitkFloat32)\n",
    "        intersection_mask = mask1 * mask2\n",
    "        composite_image = (\n",
    "            alpha * intersection_mask * composite_image\n",
    "            + (1 - alpha) * intersection_mask * current_image\n",
    "            + (mask1 - intersection_mask) * composite_image\n",
    "            + (mask2 - intersection_mask) * current_image\n",
    "        )\n",
    "    return composite_image"
   ]
  },
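  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A toy 1D analogue (illustration only) of the masked alpha blending performed by `composite_images_alpha_blending`, assuming zero denotes background and at most two images overlap: overlapping pixels are blended, non-overlapping pixels are pasted as-is."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "alpha = 0.5\n",
    "composite = np.array([0.0, 0.0, 4.0, 4.0, 4.0])  # composite so far, 0 is background\n",
    "current = np.array([2.0, 2.0, 2.0, 0.0, 0.0])  # next image, overlaps in one pixel\n",
    "mask1 = (composite != 0).astype(float)\n",
    "mask2 = (current != 0).astype(float)\n",
    "intersection_mask = mask1 * mask2\n",
    "blended = (\n",
    "    alpha * intersection_mask * composite\n",
    "    + (1 - alpha) * intersection_mask * current\n",
    "    + (mask1 - intersection_mask) * composite\n",
    "    + (mask2 - intersection_mask) * current\n",
    ")\n",
    "print(blended)  # the overlapping pixel is averaged, the rest are pasted"
   ]
  },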
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We start by performing our exploration step, obtaining multiple starting point candidates.\n",
    "\n",
    "Below we also plot the similarity metric surfaces and minimal values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metric_sampling_percentage = 0.2\n",
    "min_row_overlap = 20\n",
    "max_row_overlap = 0.5 * hip_image.GetHeight()\n",
    "column_overlap = 0.2 * hip_image.GetWidth()\n",
    "dx_step_num = 4\n",
    "dy_step_num = 10\n",
    "\n",
    "initializer = Evaluate2DTranslationCorrelation(\n",
    "    metric_sampling_percentage,\n",
    "    min_row_overlap,\n",
    "    max_row_overlap,\n",
    "    column_overlap,\n",
    "    dx_step_num,\n",
    "    dy_step_num,\n",
    ")\n",
    "\n",
    "# Get potential starting points for the knee-hip images.\n",
    "initializer.evaluate(\n",
    "    fixed_image=sitk.Cast(knee_image, sitk.sitkFloat32),\n",
    "    moving_image=sitk.Cast(hip_image, sitk.sitkFloat32),\n",
    ")\n",
    "plotting_data = [(\"knee 2 hip\", initializer.get_raw_data())]\n",
    "k2h_candidates = initializer.get_candidates(num_candidates=4, correlation_threshold=0.5)\n",
    "\n",
    "# Get potential starting points for the ankle-knee images.\n",
    "initializer.evaluate(\n",
    "    fixed_image=sitk.Cast(ankle_image, sitk.sitkFloat32),\n",
    "    moving_image=sitk.Cast(knee_image, sitk.sitkFloat32),\n",
    ")\n",
    "plotting_data.append((\"ankle 2 knee\", initializer.get_raw_data()))\n",
    "a2k_candidates = initializer.get_candidates(num_candidates=4, correlation_threshold=0.5)\n",
    "\n",
    "# Plot the similarity metric terrain and mark the minimum with a red sphere.\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "\n",
    "fig = plt.figure(figsize=(10, 4))\n",
    "for i, plot_data in enumerate(plotting_data, 1):\n",
    "    ax = fig.add_subplot(1, 2, i, projection=\"3d\")\n",
    "    ax.plot_surface(*(plot_data[1]))\n",
    "    ax.set_xlabel(\"x translation\")\n",
    "    ax.set_ylabel(\"y translation\")\n",
    "    ax.set_zlabel(\"negative correlation\")\n",
    "    ax.set_title(plot_data[0])\n",
    "    min_index = np.unravel_index((plot_data[1])[2].argmin(), (plot_data[1])[2].shape)\n",
    "    ax.scatter(\n",
    "        (plot_data[1])[0][min_index],\n",
    "        (plot_data[1])[1][min_index],\n",
    "        (plot_data[1])[2][min_index],\n",
    "        marker=\"o\",\n",
    "        color=\"red\",\n",
    "    );"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We will use the hip image coordinate system as the common coordinate system\n",
    "# and visualize the results with the transformations corresponding to the best\n",
    "# similarity metric values.\n",
    "knee2hip_transform = k2h_candidates[0][0]\n",
    "ankle2knee_transform = a2k_candidates[0][0]\n",
    "ankle2hip_transform = sitk.CompositeTransform(\n",
    "    [knee2hip_transform, ankle2knee_transform]\n",
    ")\n",
    "\n",
    "image_transform_list = [\n",
    "    (hip_image, sitk.TranslationTransform(2)),\n",
    "    (knee_image, knee2hip_transform),\n",
    "    (ankle_image, ankle2hip_transform),\n",
    "]\n",
    "composite_image = composite_images_alpha_blending(\n",
    "    create_images_in_shared_coordinate_system(image_transform_list)\n",
    ")\n",
    "\n",
    "gui.multi_image_display2D([composite_image], figure_size=(4, 8))\n",
    "print(f\"knee2hip_correlation: {k2h_candidates[0][1]:.2f}\")\n",
    "print(f\"ankle2hip_correlation: {a2k_candidates[0][1]:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Registration - Exploitation Step\n",
    "\n",
    "Now that we have a set of good (from a similarity metric standpoint) initial estimates for the transformation, we will refine them using standard gradient descent based registration.\n",
    "The final transformations are those that correspond to the best similarity metric values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def final_registration(fixed_image, moving_image, initial_mutable_transformations):\n",
    "    \"\"\"\n",
    "    Register the two images using multiple starting transformations.\n",
    "    Args:\n",
    "        fixed_image (SimpleITK image): Estimated transformation maps points from this image to the\n",
    "                                       moving_image.\n",
    "        moving_image (SimpleITK image): Estimated transformation maps points from the fixed image to\n",
    "                                        this image.\n",
    "        initial_mutable_transformations (iterable, list like): Set of initial transformations; these will\n",
    "                                                               be modified in place.\n",
    "    \"\"\"\n",
    "    registration_method = sitk.ImageRegistrationMethod()\n",
    "    registration_method.SetMetricAsCorrelation()\n",
    "    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)\n",
    "    registration_method.SetMetricSamplingPercentage(0.2)\n",
    "    registration_method.SetOptimizerAsGradientDescent(\n",
    "        learningRate=1.0, numberOfIterations=200\n",
    "    )\n",
    "    registration_method.SetOptimizerScalesFromPhysicalShift()\n",
    "\n",
    "    def reg(transform):\n",
    "        registration_method.SetInitialTransform(transform)\n",
    "        registration_method.Execute(fixed_image, moving_image)\n",
    "        return registration_method.GetMetricValue()\n",
    "\n",
    "    final_values = [reg(transform) for transform in initial_mutable_transformations]\n",
    "    return list(zip(initial_mutable_transformations, final_values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the initial transformations for use in the final registration\n",
    "initial_transformation_list_k2h = [\n",
    "    sitk.TranslationTransform(t) for t, corr in k2h_candidates\n",
    "]\n",
    "initial_transformation_list_a2k = [\n",
    "    sitk.TranslationTransform(t) for t, corr in a2k_candidates\n",
    "]\n",
    "\n",
    "# Perform the final registration\n",
    "k2h_final = final_registration(\n",
    "    fixed_image=sitk.Cast(knee_image, sitk.sitkFloat32),\n",
    "    moving_image=sitk.Cast(hip_image, sitk.sitkFloat32),\n",
    "    initial_mutable_transformations=initial_transformation_list_k2h,\n",
    ")\n",
    "a2k_final = final_registration(\n",
    "    fixed_image=sitk.Cast(ankle_image, sitk.sitkFloat32),\n",
    "    moving_image=sitk.Cast(knee_image, sitk.sitkFloat32),\n",
    "    initial_mutable_transformations=initial_transformation_list_a2k,\n",
    ")\n",
    "\n",
    "knee2hip = min(k2h_final, key=lambda x: x[1])\n",
    "knee2hip_transform = knee2hip[0]\n",
    "\n",
    "ankle2knee = min(a2k_final, key=lambda x: x[1])\n",
    "ankle2hip_transform = sitk.CompositeTransform([knee2hip_transform, ankle2knee[0]])\n",
    "\n",
    "image_transform_list = [\n",
    "    (hip_image, sitk.TranslationTransform(2)),\n",
    "    (knee_image, knee2hip_transform),\n",
    "    (ankle_image, ankle2hip_transform),\n",
    "]\n",
    "composite_image = composite_images_alpha_blending(\n",
    "    create_images_in_shared_coordinate_system(image_transform_list)\n",
    ")\n",
    "\n",
    "gui.multi_image_display2D([composite_image], figure_size=(4, 8))\n",
    "print(f\"knee2hip_correlation: {knee2hip[1]:.2f}\")\n",
    "print(f\"ankle2hip_correlation: {ankle2knee[1]:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Additional food for thought\n",
    "\n",
    "1. Does the best final transformation correspond to the best initial one?\n",
    "2. Find the optimal parameter space sampling for the exploration stage (consider both time and accuracy).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Measure the angle\n",
    "\n",
    "Start from the top and mark the points in the following order:\n",
    "1. Femoral head.\n",
    "2. Femoral mid-condylar point.\n",
    "3. Center of tibial spines notch.\n",
    "4. Center of the tibial plafond."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "measurement_gui = gui.PointDataAquisition(composite_image);"
   ]
  },
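  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next cell computes the angle from the marked points using `np.arctan2`. As a sanity check of that formulation on synthetic axes (hypothetical values): a femoral axis pointing straight down the image and a tibial axis tilted by ten degrees should recover a ten degree angle (the sign depends on the tilt direction)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Synthetic axes, hypothetical values: in image coordinates the y axis points down.\n",
    "theta = np.radians(10)\n",
    "synthetic_femoral_axis = np.array([0.0, 1.0])\n",
    "synthetic_tibial_axis = np.array([np.sin(theta), np.cos(theta)])\n",
    "synthetic_angle = np.arctan2(\n",
    "    synthetic_tibial_axis[1], synthetic_tibial_axis[0]\n",
    ") - np.arctan2(synthetic_femoral_axis[1], synthetic_femoral_axis[0])\n",
    "print(f\"Synthetic angle (degrees): {np.degrees(synthetic_angle):.1f}\")"
   ]
  },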
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "points_top_to_bottom = measurement_gui.get_points()\n",
    "if len(points_top_to_bottom) == 4:\n",
    "    femoral_axis = np.array(points_top_to_bottom[1]) - np.array(points_top_to_bottom[0])\n",
    "    tibial_axis = np.array(points_top_to_bottom[3]) - np.array(points_top_to_bottom[2])\n",
    "    angle = np.arctan2(tibial_axis[1], tibial_axis[0]) - np.arctan2(\n",
    "        femoral_axis[1], femoral_axis[0]\n",
    "    )\n",
    "\n",
    "    print(f\"Angle is (degrees): {np.degrees(angle)}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
